import torch
import argparse
from tqdm import tqdm
from torchmetrics import R2Score
import datetime
from src.dataset import PPDataset
from torch.utils.tensorboard import SummaryWriter
from test import ModelAnalyzer, DistanceLoss, BCELoss_simple, CELoss, WeighedBCELoss, WeighedBCELossV2

# Run on the first CUDA device when available, otherwise fall back to the CPU.
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")
dt = datetime.datetime.now()
# Compact timestamp (YYMMDDHHMMSS) used to tag log directories and checkpoints.
dt_str = dt.strftime("%y%m%d%H%M%S")

# Command-line training parameters.
parser = argparse.ArgumentParser(description='PyTorch model trainer')
parser.add_argument('--model', default="unetstft2", help='name of model')
parser.add_argument('--pre', default="fcy2",
                    help='name of dataset preprocess method')
parser.add_argument('--lr', type=float, default=0.0005,
                    help='learn rate of optimizer')
parser.add_argument('--epoch_size', type=int, default=10,
                    help='how much epoch to train')
parser.add_argument('--batch_size', type=int, default=32,
                    help='number of waveforms in each batch')
parser.add_argument('--data_len', type=int, default=4000,
                    help='samples in each piece of data')
parser.add_argument('--ds_path', default="E:/RealSeisData/Diting50hz/",
                    # Fixed: help text was copy-pasted from --pre and misdescribed this option.
                    help='path of the dataset root directory')

if __name__ == "__main__":
    # Read the command-line arguments (unknown flags are tolerated, not rejected).
    args, unknown = parser.parse_known_args()
    # Dotted module path of the model implementation, e.g. "src.model_unetstft2".
    model_name = 'src.model_'+args.model
    pre_method = args.pre
    dataset_path = args.ds_path
    lr = args.lr
    # Final learning rate for the cosine-annealing schedule (1/50 of the start value).
    lr_final = lr/50
    batch_size = args.batch_size
    epoch_size = args.epoch_size
    dlen = args.data_len

    # Echo the parsed arguments, one "name   value" line each, left-aligned.
    for arg in vars(args):
        print(format(arg, '<15'), format(
            str(getattr(args, arg)), '<'))
    print("device:", device)

    # TensorBoard writer; runs are grouped by model module name and timestamp.
    writer = SummaryWriter(log_dir="./logs/{}/{}".format(model_name, dt_str))
    # Load the train / validation datasets.
    print('loading dataset...')
    # NOTE(review): 'methodmame' looks like a typo of 'methodname', but it must
    # match PPDataset's actual keyword — verify against src.dataset before renaming.
    train_dataset = PPDataset(
        dataset_path, "DiTing330km_train.csv", methodmame=pre_method, dlen=dlen)
    val_dataset = PPDataset(
        dataset_path, "DiTing330km_validation.csv", methodmame=pre_method, dlen=dlen)
    # NOTE(review): no shuffle=True here, so samples are visited in CSV order
    # every epoch — confirm the CSV is pre-shuffled or enable shuffling.
    train_loader = torch.utils.data.DataLoader(
        train_dataset, batch_size=batch_size, num_workers=8)
    # NOTE(review): val_loader is currently unused (the validation-loss loop
    # below is commented out); kept for when that loop is re-enabled.
    val_loader = torch.utils.data.DataLoader(
        val_dataset, batch_size=batch_size*8, num_workers=4)

    # Dynamically import the model module and instantiate its Model class.
    modelpkg = __import__(model_name, fromlist=['Model'])
    model = modelpkg.Model().to(device)
    # Warm-start from a pretrained checkpoint.
    # NOTE(review): the path is hardcoded, and torch.load unpickles a full model
    # object — only load checkpoints from a trusted source.
    model_pretrain = torch.load("./model/unetstft2_231101154751_ep16.pth")
    model.load_state_dict(model_pretrain.state_dict())
    # Optimizer
    # opt = torch.optim.SGD(model.parameters(), lr=lr)
    opt = torch.optim.AdamW(
        model.parameters(), lr=lr, weight_decay=0.001)
    # Cosine annealing from lr down to lr_final over epoch_size epochs.
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
        opt, epoch_size, lr_final)

    # Loss functions / evaluation metrics (several kept around but unused below).
    # loss_MAE = torch.nn.L1Loss()
    # loss_SmoothMAE = torch.nn.SmoothL1Loss()
    loss_R2 = R2Score(num_outputs=6)
    # loss_BCE = torch.nn.BCELoss(reduction='none')
    loss_CE = CELoss()
    loss_BCE = BCELoss_simple()
    loss_Distance = DistanceLoss()
    loss_WeighedBCE = WeighedBCELossV2(w=0.8)

    # P / S / N loss-weight tensor (disabled).
    # w = torch.ones(1, 2, 1).to(device)
    # w[0, 0, 0] = 0.4 #P
    # w[0, 1, 0] = 0.6 #S
    # w=w * 2 # let mean == 1
    # Evaluates saved checkpoints against the validation CSV via peak finding.
    analyzer = ModelAnalyzer(
        'DiTing330km_validation.csv', 0.2, 25, 'fcy2', dlen=dlen)

    # Train for epoch_size epochs.
    for epoch in range(epoch_size):
        print('\n[ epoch {} ]'.format(epoch + 1))
        loss_sum = 0
        model.train()
        pbar = tqdm(train_loader, ncols=0, mininterval=1)
        pbar.set_description("train")
        for train_x, train_label in pbar:
            train_x = train_x.to(device)
            train_label = train_label.to(device)
            # Keep only the first two output channels (presumably the P and S
            # pick maps — TODO confirm against the model definition).
            predict_y = model(train_x.float())[:, :2]
            # loss1 = (loss_BCE(predict_y, train_label.float())*w).mean()
            loss1 = loss_WeighedBCE(predict_y, train_label.float())
            # loss1 = loss_BCE(predict_y, train_label.float()).mean()
            # loss2 is computed for progress display only; it does not affect
            # the gradient while the '+ loss2' term below stays commented out.
            loss2 = loss_Distance(predict_y, train_label.float())
            loss = loss1 #+ loss2  # final loss
            opt.zero_grad()
            loss.backward()
            opt.step()
            pbar.set_postfix(
                loss1='{:.6f}'.format(loss1.item()),
                loss2='{:.6f}'.format(loss2.item()), refresh=False)
            loss_sum += loss1.item()
        print('train bce loss: {:.6f}'.format(loss_sum/len(train_loader)))
        # Save a checkpoint (entire model object, not just state_dict) after
        # every epoch.
        # The model output location is placed under /model
        model_fname = './model/{}_{}_ep{}.pth'.format(
            args.model, dt_str, epoch + 1)
        torch.save(
            model, model_fname)
        print(model_fname)
        scheduler.step()
        print('learn_rate:', scheduler.get_last_lr()[0])

        # Validation-loss pass (disabled; checkpoint quality is instead
        # measured by the analyzer call below).
        # loss1_sum = 0
        # model.eval()
        # pbar = tqdm(val_loader, ncols=0, mininterval=1)
        # pbar.set_description("validation")
        # for val_x, val_label in pbar:
        #     val_x = val_x.to(device)
        #     val_label = val_label.to(device)
        #     predict_y = model(val_x.float()).detach().to(device)[:, :2]
        #     loss1 = loss_BCE(predict_y, val_label.float())
        #     pbar.set_postfix(
        #         loss1='{:.6f}'.format(loss1.item()), refresh=False)
        #     loss1_sum += loss1.item()
        # print('vali loss1: {:.6f}'.format(loss1_sum/len(val_loader)))

        # Evaluate the just-saved checkpoint: recall / precision / F1 for the
        # P and S phases via peak finding on the validation set.
        p_recall, p_precision, p_f1, s_recall, s_precision, s_f1 = analyzer.analyse_findpeak(
            model_fname)

        # Log per-epoch metrics to TensorBoard (x-axis is the 1-based epoch).
        writer.add_scalar('learn_rate', scheduler.get_last_lr()[0], epoch + 1)
        writer.add_scalar('Loss/train', loss_sum/len(train_loader), epoch + 1)
        # writer.add_scalar('Loss/val', loss1_sum/len(val_loader), epoch)
        writer.add_scalar('Recall/p', p_recall, epoch + 1)
        writer.add_scalar('Precision/p', p_precision, epoch + 1)
        writer.add_scalar('F1/p', p_f1, epoch + 1)
        writer.add_scalar('Recall/s', s_recall, epoch + 1)
        writer.add_scalar('Precision/s', s_precision, epoch + 1)
        writer.add_scalar('F1/s', s_f1, epoch + 1)
